import numpy as np
from tqdm import tqdm
import os
from sklearn.metrics import pairwise

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

from dataset import CustomDataset
from network import compute_gradients



#################################### Data Generation ####################################

def generate_multivariate_t(df, d, n_samples, loc=None, covmat = None, random_seed=None):
    """Draw samples from a d-dimensional multivariate Student-t distribution.

    Uses the Gaussian / chi-square mixture representation
    ``X = loc + Z / sqrt(U)`` with ``Z ~ N(0, covmat)`` and
    ``U ~ chi2(df) / df``. Note that ``covmat`` is the *scale* matrix of the
    t distribution, not its covariance.

    Args:
        df: Degrees of freedom of the t distribution.
        d: Dimension of each sample.
        n_samples: Number of samples to draw.
        loc: Location vector of length ``d``; defaults to the zero vector.
        covmat: ``d x d`` scale matrix; defaults to the identity.
        random_seed: If given, the *global* NumPy RNG is reseeded with it
            (a process-wide side effect) before sampling.

    Returns:
        Array of shape ``(n_samples, d)``.
    """
    if random_seed is not None:
        np.random.seed(random_seed)

    center = np.zeros(d) if loc is None else loc
    scale = np.eye(d) if covmat is None else covmat

    # The Gaussian draw must come before the chi-square draw so that seeded
    # runs reproduce the same stream of random numbers.
    gaussian = np.random.multivariate_normal(mean=np.zeros(d), cov=scale, size=n_samples)
    mixing = np.sqrt(np.random.chisquare(df, size=n_samples) / df)

    return center + gaussian / mixing.reshape(-1, 1)


#################################### Evaluation ####################################
def compute_mmd(X, Y, kernel='rbf', gamma=None):
    """Return the biased (V-statistic) estimate of the squared MMD between X and Y.

    Computes ``mean(K(X, X)) + mean(K(Y, Y)) - 2 * mean(K(X, Y))``. The
    diagonal kernel terms are included, so this is the biased estimator of
    MMD^2.

    Args:
        X: Array of shape ``(n, d)``.
        Y: Array of shape ``(m, d)``.
        kernel: Kernel name. Only ``'rbf'`` is supported.
        gamma: RBF bandwidth parameter forwarded to
            ``sklearn.metrics.pairwise.rbf_kernel``. ``None`` keeps sklearn's
            default of ``1 / n_features``, which matches the previous behaviour.

    Returns:
        The scalar MMD^2 estimate.

    Raises:
        ValueError: If ``kernel`` is not supported. Previously this case
            silently returned ``None``.
    """
    if kernel != 'rbf':
        raise ValueError(f"Unsupported kernel {kernel!r}; only 'rbf' is implemented.")

    k_xx = pairwise.rbf_kernel(X, gamma=gamma)
    k_yy = pairwise.rbf_kernel(Y, gamma=gamma)
    k_xy = pairwise.rbf_kernel(X, Y, gamma=gamma)
    return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()



#################################### Sieved Conjugate ####################################
def inner_product(x, y):
    """Return the per-sample inner product of ``x`` and ``y``.

    The element-wise product is summed over every dimension of ``y`` except
    the first (batch) dimension. For batched inputs of shape ``(B, ...)`` the
    result therefore has shape ``(B,)``.
    """
    reduce_dims = tuple(range(1, y.dim()))
    return (x * y).sum(dim=reduce_dims)


def compute_convex_conjugate_sieved(y, phi, inner_epoch, Mn, lr=0.001):
    """Approximate the sieved convex conjugate ``phi*(y) = sup_{||x|| <= Mn} <x, y> - phi(x)``.

    The supremum is found with projected gradient descent on
    ``phi(x) - <x, y>``. The iterate starts at ``x = 0``. After each step,
    every sample (the first dimension indexes samples) is projected back onto
    the Euclidean ball ``B(0, Mn)``.

    Args:
        y: Batched tensor whose first dimension is the batch dimension.
        phi: Callable applied to a tensor shaped like ``y``. Its output minus
            ``inner_product(x, y)`` is summed to form the objective.
            NOTE(review): if ``phi`` returns shape ``(B, 1)`` instead of
            ``(B,)``, the subtraction broadcasts to ``(B, B)``. That scales the
            inner gradient by ``B`` and makes the returned tensor ``(B, B)``,
            though its mean is unchanged. Confirm the model's output shape.
        inner_epoch: Number of projected gradient steps.
        Mn: Radius of the sieve ball. ``np.inf`` disables the constraint.
        lr: Gradient step size. The default of 0.001 is the previously
            hard-coded value.

    Returns:
        ``<x*, y> - phi(x*)`` evaluated at the (detached) approximate
        maximiser ``x*``. Gradients still flow into ``phi``'s parameters
        through ``phi(x*)``, which is the envelope-theorem gradient of the
        conjugate.
    """
    x = torch.zeros_like(y, requires_grad=True)
    # Per-sample factors of shape (B,) are reshaped to broadcast over all
    # non-batch dimensions of x.
    broadcast_shape = (-1,) + (1,) * (y.dim() - 1)
    # Lower bound on norms, so a zero row with Mn == 0 yields 0 rather than 0/0 = NaN.
    tiny = torch.finfo(x.dtype).tiny

    for _ in range(inner_epoch):
        loss = (phi(x) - inner_product(x, y)).sum()
        # Differentiate w.r.t. x only. Unlike loss.backward(), this neither
        # accumulates gradients into phi's parameters nor needs a retained graph.
        (grad,) = torch.autograd.grad(loss, x)

        with torch.no_grad():
            x.add_(grad, alpha=-lr)  # plain SGD step
            # Project each sample onto the ball B(0, Mn).
            norms = x.flatten(start_dim=1).norm(dim=1)
            scaling_factors = torch.clamp(Mn / norms.clamp_min(tiny), max=1.0)
            x.mul_(scaling_factors.view(broadcast_shape))

    x_opt = x.detach()
    return inner_product(x_opt, y) - phi(x_opt)


#################################### Train NN ####################################
    
def train_brenier(x, y, batch_size, model,
                  learning_rate1=0.005, learning_rate2=0.001,
                  num_epochs1=50, num_epochs2=50,
                  num_epochs_convex_conjugate=500,
                  sieved=np.inf,
                  bar=False, val_ratio=0.1,
                  model_save_path=None, save=False, load=False):
    """Train the potential network ``model`` on paired samples ``x`` and ``y``.

    Each step minimises ``mean(model(x)) + mean(model*(y))``. Here ``model*``
    is the sieved convex conjugate approximated by
    ``compute_convex_conjugate_sieved`` with ``num_epochs_convex_conjugate``
    inner steps and ball radius ``sieved``. Training runs in two phases:
    ``num_epochs1`` epochs of Adam at ``learning_rate1``, then
    ``num_epochs2`` epochs of a fresh Adam at ``learning_rate2``.

    When a validation split exists, each epoch ends by mapping the validation
    ``x`` through ``compute_gradients(model, .)`` (presumably the gradient
    map of the potential; see network.py). The result is compared with the
    validation ``y`` using the RBF MMD.

    Args:
        x, y: Source/target samples, as numpy arrays (converted to float32
            tensors) or tensors. Passed to ``CustomDataset``.
        batch_size: Mini-batch size for training and validation loaders.
        model: ``torch.nn.Module`` potential, trained in place.
        learning_rate1, learning_rate2: Adam learning rates for the two phases.
        num_epochs1, num_epochs2: Number of epochs in each phase.
        num_epochs_convex_conjugate: Inner steps of the conjugate solver.
        sieved: Sieve radius ``Mn`` for the conjugate solver (``np.inf`` = none).
        bar: Show a tqdm progress bar with the current losses.
        val_ratio: Fraction of the dataset held out for validation.
        model_save_path: Path used for loading and/or saving the state dict.
        save: If True and ``model_save_path`` is set, save the state dict
            after training.
        load: If True and ``model_save_path`` points to an existing file,
            load the state dict before training.

    Returns:
        ``(model, train_loss_record, val_loss_record)``. The first list holds
        the mean training loss per epoch. The second holds the validation MMD
        per epoch and is empty when there is no validation split.

    Raises:
        ValueError: If ``val_ratio`` leaves no samples for training.
    """
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x).float()
    if isinstance(y, np.ndarray):
        y = torch.from_numpy(y).float()

    dataset = CustomDataset(x, y)  # see dataset.py

    # Split dataset into training and validation sets.
    val_size = int(len(dataset) * val_ratio)
    train_size = len(dataset) - val_size
    if train_size <= 0:
        raise ValueError(f"val_ratio={val_ratio} leaves no training samples.")
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate1)

    # Guard against model_save_path=None: os.path.exists(None) raises TypeError.
    if load and model_save_path is not None and os.path.exists(model_save_path):
        model.load_state_dict(torch.load(model_save_path, weights_only=True))

    num_epochs = num_epochs1 + num_epochs2
    epoch_progress = tqdm(range(num_epochs), desc='Training Epochs', unit='epoch') if bar else range(num_epochs)

    train_loss_record = []
    val_loss_record = []

    for epoch in epoch_progress:
        # Epochs are 0-indexed, so phase 2 starts right after num_epochs1
        # completed epochs (this also covers num_epochs1 == 0).
        if epoch == num_epochs1:
            optimizer = optim.Adam(model.parameters(), lr=learning_rate2)

        model.train()
        epoch_loss = 0.0
        for x_inputs, y_inputs in train_dataloader:
            loss_main = model(x_inputs).mean()
            loss_convex = compute_convex_conjugate_sieved(
                y_inputs, model, num_epochs_convex_conjugate, sieved).mean()
            loss = loss_main + loss_convex

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_epoch_loss = epoch_loss / len(train_dataloader)
        # Recorded every epoch, whether or not validation runs.
        train_loss_record.append(avg_epoch_loss)
        postfix = {'Mean Loss': avg_epoch_loss}

        # Check the actual split size: a small val_ratio can still round to 0 samples.
        if val_size > 0:
            epoch_val_loss = _validation_mmd(model, val_dataloader)
            val_loss_record.append(epoch_val_loss)
            postfix['Validation Loss'] = epoch_val_loss

        if bar:
            epoch_progress.set_postfix(postfix)

    if save and model_save_path is not None:
        torch.save(model.state_dict(), model_save_path)

    return model, train_loss_record, val_loss_record


def _validation_mmd(model, dataloader):
    """Return the RBF MMD between ``compute_gradients(model, x)`` and ``y`` over ``dataloader``."""
    model.eval()
    predicted, target = [], []
    for val_x_inputs, val_y_inputs in dataloader:
        # compute_gradients needs autograd w.r.t. the inputs, so no torch.no_grad() here.
        y_hat = compute_gradients(model, val_x_inputs)
        # .cpu() makes the conversion safe for GPU tensors.
        predicted.append(y_hat.detach().cpu().numpy())
        target.append(val_y_inputs.detach().cpu().numpy())
    # Stack the batches into (total_samples, num_features) arrays.
    return compute_mmd(np.vstack(predicted), np.vstack(target))


    

